{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sSVXJYfVvCJB"
      },
      "source": [
        "#### This notebook was created by [Nitin Tiwari](https://linkedin.com/in/tiwari-nitin).\n",
        "\n",
        "#### **Social links:**\n",
        "* [LinkedIn](https://linkedin.com/in/tiwari-nitin)\n",
        "* [GitHub](https://github.com/NSTiwari)\n",
        "* [Twitter](https://x.com/NSTiwari21)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W634WCRovEdo"
      },
      "source": [
        "# Referring Expression Segmentation in videos"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NTGOdrsTvHHw"
      },
      "source": [
        "This notebook shows how to perform referring expression segmentation on videos using [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) and how to draw the predicted masks on each frame using OpenCV and PIL.\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/PaliGemma/[PaliGemma_1]Referring_expression_segmentation_in_videos.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X4YM125TvdjZ"
      },
      "source": [
        "### Get access to PaliGemma\n",
        "\n",
        "Before using PaliGemma for the first time, you must request access to the model through Hugging Face by completing the following steps:\n",
        "\n",
        "1. Log in to [Hugging Face](https://huggingface.co), or create a new Hugging Face account if you don't already have one.\n",
        "2. Go to the [PaliGemma model card](https://huggingface.co/google/paligemma-3b-mix-224) to get access to the model.\n",
        "3. Complete the consent form and accept the terms and conditions.\n",
        "\n",
        "To generate a Hugging Face token, open your [**Settings** page in Hugging Face](https://huggingface.co/settings), choose the **Access Tokens** option in the left pane, and click **New token**. In the window that appears, name your token and set its type to **Write**.\n",
        "\n",
        "Then, in Colab, select **Secrets** (🔑) in the left pane and add your Hugging Face token. Store your Hugging Face token under the name `HF_TOKEN`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GmCMot7Gvfcg"
      },
      "source": [
        "### Select the runtime\n",
        "\n",
        "To complete this tutorial, you need a Colab runtime with enough resources to run the PaliGemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, click the **▾ (Additional connection options)** dropdown menu.\n",
        "1. Select **Change runtime type**.\n",
        "1. Under **Hardware accelerator**, select **T4 GPU**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tatYlRwvbDNY"
      },
      "source": [
        "### Step 1: Install libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DeBbVm6pa-Lt"
      },
      "outputs": [],
      "source": [
        "!pip install bitsandbytes transformers accelerate peft -q\n",
        "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H6gLDltJbTpc"
      },
      "source": [
        "### Step 2: Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NLXxDAQvbVhk"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer, PaliGemmaForConditionalGeneration, PaliGemmaProcessor\n",
        "import torch\n",
        "import numpy as np\n",
        "import cv2\n",
        "import os\n",
        "import re\n",
        "import matplotlib.pyplot as plt\n",
        "import sys\n",
        "from PIL import Image, ImageDraw, ImageFont"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ShmfbLUabZHr"
      },
      "source": [
        "### Step 3: Fetch the `big_vision` repository"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6_Cd2CZEba25"
      },
      "outputs": [],
      "source": [
        "if not os.path.exists(\"big_vision_repo\"):\n",
        "  !git clone --quiet --branch=main --depth=1 \\\n",
        "     https://github.com/google-research/big_vision big_vision_repo\n",
        "\n",
        "# Append big_vision code to Python import path.\n",
        "if \"big_vision_repo\" not in sys.path:\n",
        "  sys.path.append(\"big_vision_repo\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HKJruSCIbcGb"
      },
      "source": [
        "### Step 4: Set the Hugging Face token as an environment variable"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0KDj2WylvoRB"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get('HF_TOKEN')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lpinx3KIbhMA"
      },
      "source": [
        "### Step 5: Load pre-trained PaliGemma model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Evn_cTrCbjCO"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model_id = \"google/paligemma-3b-mix-224\"\n",
        "model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)\n",
        "processor = PaliGemmaProcessor.from_pretrained(model_id)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-0dBGTR4b_YR"
      },
      "source": [
        "### Step 6: Define functions to parse and draw segmentation masks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BJH5zjHVcBBX"
      },
      "outputs": [],
      "source": [
        "import big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval\n",
        "reconstruct_masks = segeval.get_reconstruct_masks('oi')\n",
        "\n",
        "def show_segmentation_output(image, segment_masks, image_size, coordinates_list, labels):\n",
        "  # Overlay every mask on the image, then draw a color legend.\n",
        "  height, width = image_size\n",
        "  # `label_colors` is a global dict (created before the video loop) that maps labels to RGBA colors.\n",
        "  global label_colors\n",
        "  masked_image = Image.fromarray(np.uint8(image.copy()))\n",
        "\n",
        "  for i, segment_mask in enumerate(segment_masks):\n",
        "    coordinates = coordinates_list[i]\n",
        "    label = labels[i]\n",
        "\n",
        "    # Pick one random semi-transparent color per label and reuse it across frames.\n",
        "    if label not in label_colors:\n",
        "      label_colors[label] = tuple(int(c) for c in np.random.randint(256, size=3)) + (128,)\n",
        "    color = label_colors[label]\n",
        "\n",
        "    # Convert the normalized box to pixel coordinates.\n",
        "    y1, x1, y2, x2 = coordinates\n",
        "    y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))\n",
        "\n",
        "    # Get the box width and height, and skip degenerate boxes.\n",
        "    w = x2 - x1\n",
        "    h = y2 - y1\n",
        "    if w <= 0 or h <= 0:\n",
        "      continue\n",
        "\n",
        "    # Map every pixel of the box to its nearest cell in the 64x64 mask.\n",
        "    x_coords = np.arange(w) * 64 // w\n",
        "    y_coords = np.arange(h) * 64 // h\n",
        "\n",
        "    # Resize the segment mask to the box size with nearest-neighbor sampling.\n",
        "    resized_segment_mask = np.asarray(segment_mask)[y_coords[:, np.newaxis], x_coords].reshape(h, w)\n",
        "\n",
        "    # Build an RGBA overlay: colored where the mask is positive, transparent elsewhere.\n",
        "    overlay = np.zeros((h, w, 4), dtype=np.uint8)\n",
        "    overlay[resized_segment_mask > 0] = color\n",
        "    mask = Image.fromarray(overlay)\n",
        "\n",
        "    # Paste the overlay at the box origin, using its alpha channel for blending.\n",
        "    masked_image.paste(mask, (x1, y1), mask)\n",
        "\n",
        "  masked_output = np.array(masked_image.convert('RGB'))\n",
        "\n",
        "  # Overlay the legend on the image.\n",
        "  legend_y = int(height * 0.03)\n",
        "  legend_box_width = int(width * 0.05)\n",
        "  legend_box_height = int(height * 0.04)\n",
        "  for label, color in label_colors.items():\n",
        "    legend_entry_x1 = int(width * 0.84)\n",
        "    legend_entry_y1 = legend_y\n",
        "    legend_entry_x2 = legend_entry_x1 + legend_box_width\n",
        "    legend_entry_y2 = legend_y + legend_box_height\n",
        "    cv2.rectangle(masked_output, (legend_entry_x1, legend_entry_y1), (legend_entry_x2, legend_entry_y2), color[:3], -1)\n",
        "\n",
        "    font_scale = min(1, legend_box_height / 20)\n",
        "    font_thickness = max(1, int(font_scale * 2))  # Scale the thickness with the font size.\n",
        "    text_x = int(width * 0.90)  # Leave padding after the color box.\n",
        "    text_y = int((legend_entry_y1 + legend_entry_y2)/2 + legend_box_height*0.1)\n",
        "\n",
        "    cv2.putText(masked_output, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), font_thickness)\n",
        "    legend_y += legend_box_height + max(height // 200, 5)\n",
        "\n",
        "  return masked_output\n",
        "\n",
        "def parse_segments(detokenized_output: str) -> tuple[list[float], np.ndarray]:\n",
        "  # Match four <locXXXX> box tokens followed by sixteen <segXXX> mask tokens.\n",
        "  matches = re.finditer(\n",
        "      r'<loc(?P<y0>\\d\\d\\d\\d)><loc(?P<x0>\\d\\d\\d\\d)><loc(?P<y1>\\d\\d\\d\\d)><loc(?P<x1>\\d\\d\\d\\d)>'\n",
        "      + ''.join(rf'<seg(?P<s{i}>\\d\\d\\d)>' for i in range(16)),\n",
        "      detokenized_output,\n",
        "  )\n",
        "  boxes, segs = [], []\n",
        "  fmt_box = lambda x: float(x) / 1024.0\n",
        "  for m in matches:\n",
        "    d = m.groupdict()\n",
        "    boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])\n",
        "    segs.append([int(d[f's{i}']) for i in range(16)])\n",
        "\n",
        "  if not boxes:\n",
        "    raise ValueError(f'No segmentation found in: {detokenized_output!r}')\n",
        "\n",
        "  # Return the first box as normalized (y0, x0, y1, x1) and the decoded masks.\n",
        "  coordinates = boxes[0]\n",
        "  mask = np.array(reconstruct_masks(np.array(segs)))\n",
        "\n",
        "  return coordinates, mask"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aDqVc7zevy96"
      },
      "source": [
        "### Step 7: Configure the input video and text prompt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "M372oAWXv3gB"
      },
      "outputs": [],
      "source": [
        "input_video = 'input_video.mp4' # @param {type:\"string\"}\n",
        "\n",
        "prompt = \"segment person, mug, book\" # @param {type: \"string\"}\n",
        "prompt = prompt.replace(',', '\\n')\n",
        "\n",
        "output_file = 'segmentation_output_video.avi' # @param {type: \"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5qIdmXoAwI1q"
      },
      "source": [
        "### Step 8: Pass the input video and text prompt to PaliGemma"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pCJ63nSncZsw"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Output video segmentation_output_video.avi saved to disk.\n"
          ]
        }
      ],
      "source": [
        "# Open the input video file.\n",
        "cap = cv2.VideoCapture(input_video)\n",
        "if not cap.isOpened():\n",
        "    raise IOError(\"Could not open input video: \" + input_video)\n",
        "\n",
        "# Match the output frame size and frame rate to the input (fall back to 20 FPS).\n",
        "frame_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))\n",
        "fps = cap.get(cv2.CAP_PROP_FPS) or 20.0\n",
        "\n",
        "# Define output video codec and file name.\n",
        "fourcc = cv2.VideoWriter_fourcc(*'XVID')\n",
        "out = cv2.VideoWriter(output_file, fourcc, fps, frame_size)\n",
        "\n",
        "label_colors = {}\n",
        "\n",
        "# Move the model to the target device once, before processing any frames.\n",
        "model.to(device)\n",
        "\n",
        "while True:\n",
        "    ret, frame = cap.read()\n",
        "    if not ret:\n",
        "        break\n",
        "\n",
        "    # Convert the BGR frame to an RGB PIL Image.\n",
        "    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n",
        "\n",
        "    # Send text prompt and image as input.\n",
        "    inputs = processor(text=prompt, images=img,\n",
        "                      padding=\"longest\", do_convert_rgb=True, return_tensors=\"pt\").to(device)\n",
        "    inputs = inputs.to(dtype=model.dtype)\n",
        "\n",
        "    # Get output.\n",
        "    with torch.no_grad():\n",
        "        output = model.generate(**inputs, max_length=496)\n",
        "\n",
        "    paligemma_response = processor.decode(output[0], skip_special_tokens=True)[len(prompt):].lstrip(\"\\n\")\n",
        "    detections = paligemma_response.split(\" ; \")\n",
        "\n",
        "    # Parse the bounding box, mask, and label of each detection.\n",
        "    coordinates_list = []\n",
        "    labels = []\n",
        "    segment_masks = []\n",
        "\n",
        "    for detection in detections:\n",
        "        parts = detection.split(\" \")\n",
        "        if len(parts) < 3:\n",
        "            continue  # Skip empty or malformed detections.\n",
        "        locations, segmentations, label = parts[0], parts[1], \" \".join(parts[2:])\n",
        "        try:\n",
        "            bbox, seg_output = parse_segments(locations + segmentations)\n",
        "        except (ValueError, IndexError):\n",
        "            continue  # No valid location/segmentation tokens were found.\n",
        "        segment_masks.append(seg_output[0])\n",
        "        coordinates_list.append(bbox)\n",
        "        labels.append(label)\n",
        "\n",
        "    height, width = frame.shape[:2]\n",
        "\n",
        "    # Draw all segmentation masks in one pass; keep the frame unchanged if nothing was detected.\n",
        "    if segment_masks:\n",
        "        output_frame = show_segmentation_output(frame, segment_masks, (height, width), coordinates_list, labels)\n",
        "    else:\n",
        "        output_frame = frame\n",
        "\n",
        "    # Write the frame to the output video.\n",
        "    out.write(output_frame)\n",
        "\n",
        "# Release the video capture and the output video writer.\n",
        "cap.release()\n",
        "out.release()\n",
        "print(\"Output video \" + output_file + \" saved to disk.\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[PaliGemma_1]Referring_expression_segmentation_in_videos.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
